// Copyright (c) 2015-2019 Amanieu d'Antras and DGuco(杜国超).All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include "internal.h"

// For GetProcAddress and GetModuleHandle
#ifdef _WIN32
#include <windows.h>
#endif

// for pthread thread_local emulation
#if defined(EMULATE_PTHREAD_THREAD_LOCAL)
# include <pthread.h>
#endif

namespace stagefuture
{
namespace detail
{

// Per-thread data, aligned to cachelines to avoid false sharing
struct LIBASYNC_CACHELINE_ALIGN thread_data_t
{
    work_steal_queue queue;
    std::minstd_rand rng;
    task_wait_event event;
    std::thread handle;
};

// Internal data used by threadpool_scheduler
struct threadpool_data
{
    ~threadpool_data()
    {
    }

    threadpool_data(std::size_t num_threads)
        : thread_data(num_threads), shutdown(false), num_waiters(0), waiters(new task_wait_event *[num_threads])
    {}

    threadpool_data(std::size_t num_threads, std::function<void()> &&prerun_, std::function<void()> &&postrun_)
        : thread_data(num_threads), shutdown(false), num_waiters(0), waiters(new task_wait_event *[num_threads]),
          prerun(std::move(prerun_)), postrun(std::move(postrun_))
    {}

    // Mutex protecting everything except thread_data
    std::mutex lock;

    // Array of per-thread data
    aligned_array<thread_data_t> thread_data;

    // Global queue for tasks from outside the pool
    fifo_queue public_queue;

    // Shutdown request indicator
    bool shutdown;

    // List of threads waiting for tasks to run. num_waiters needs to be atomic
    // because it is sometimes read outside the mutex.
    std::atomic<std::size_t> num_waiters;
    std::unique_ptr<task_wait_event *[]> waiters;

    // Pre/Post run functions.
    std::function<void()> prerun;
    std::function<void()> postrun;

#ifdef BROKEN_JOIN_IN_DESTRUCTOR
    // Shutdown complete event, used instead of thread::join()
    std::size_t shutdown_num_threads;
    std::condition_variable shutdown_complete_event;
#endif
};

// this wrapper encapsulates both the owning_threadpool pointer and the thread id.
// this is done to improve performance on the emulated thread_local reducing the number
// of calls to "pthread_getspecific"
struct threadpool_data_wrapper
{
    threadpool_data *owning_threadpool;
    std::size_t thread_id;

    threadpool_data_wrapper(threadpool_data *owning_threadpool, std::size_t thread_id)
        :
        owning_threadpool(owning_threadpool), thread_id(thread_id)
    {}
};

#if defined(EMULATE_PTHREAD_THREAD_LOCAL)
struct pthread_emulation_threadpool_data_initializer {
    pthread_key_t key;

    pthread_emulation_threadpool_data_initializer()
    {
        pthread_key_create(&key, [](void* wrapper_ptr) {
            threadpool_data_wrapper* wrapper = static_cast<threadpool_data_wrapper*>(wrapper_ptr);
            delete wrapper;
        });
    }

    ~pthread_emulation_threadpool_data_initializer()
    {
        pthread_key_delete(key);
    }
};

static pthread_key_t get_local_threadpool_data_key()
{
    static pthread_emulation_threadpool_data_initializer initializer;
    return initializer.key;
}

#else

// Thread pool this thread belongs to, or null if not in pool
static THREAD_LOCAL threadpool_data *owning_threadpool = nullptr;

// Current thread's index in the pool
static THREAD_LOCAL std::size_t thread_id;

#endif

static void create_threadpool_data(threadpool_data *owning_threadpool_, std::size_t thread_id_)
{
#if defined(EMULATE_PTHREAD_THREAD_LOCAL)
    // the memory allocated here gets deallocated by the lambda declared on the key creation
    pthread_setspecific(get_local_threadpool_data_key(), new threadpool_data_wrapper(owning_threadpool_, thread_id_));
#else
    owning_threadpool = owning_threadpool_;
    thread_id = thread_id_;
#endif
}

static threadpool_data_wrapper get_threadpool_data_wrapper()
{
#if defined(EMULATE_PTHREAD_THREAD_LOCAL)
    threadpool_data_wrapper* wrapper = static_cast<threadpool_data_wrapper*>(pthread_getspecific(get_local_threadpool_data_key()));
    if(wrapper == nullptr) {
        // if, for some reason, the wrapper is not set, this won't cause a crash
        return threadpool_data_wrapper(nullptr, 0);
    }
    return *wrapper;
#else
    return threadpool_data_wrapper(owning_threadpool, thread_id);
#endif
}

// Try to steal a task from another thread's queue
static task_run_handle steal_task(threadpool_data *impl, std::size_t thread_id)
{
    // Make a list of victim thread ids and shuffle it
    std::vector<std::size_t> victims(impl->thread_data.size());
    std::iota(victims.begin(), victims.end(), 0);
    std::shuffle(victims.begin(), victims.end(), impl->thread_data[thread_id].rng);

    // Try to steal from another thread
    for (std::size_t i: victims) {
        // Don't try to steal from ourself
        if (i == thread_id)
            continue;

        if (task_run_handle t = impl->thread_data[i].queue.steal())
            return t;
    }

    // No tasks found, but we might have missed one if it was just added. In
    // practice this doesn't really matter since it will be handled by another
    // thread.
    return task_run_handle();
}

// Main task stealing loop which is used by worker threads when they have
// nothing to do.
static void thread_task_loop(threadpool_data *impl, std::size_t thread_id, task_wait_handle wait_task)
{
    // Get our thread's data
    thread_data_t &current_thread = impl->thread_data[thread_id];

    // Flag indicating if we have added a continuation to the task
    bool added_continuation = false;

    // Loop while waiting for the task to complete
    while (true) {
        // Check if the task has finished. If we have added a continuation, we
        // need to make sure the event has been signaled.
        if (wait_task
            && (added_continuation ? current_thread.event.try_wait(wait_type::task_finished) : wait_task.ready()))
            return;

        // Try to get a task from the local queue
        if (task_run_handle t = current_thread.queue.pop()) {
            t.run();
            continue;
        }

        // Stealing loop
        while (true) {
            // Try to steal a task
            if (task_run_handle t = steal_task(impl, thread_id)) {
                t.run();
                break;
            }

            // Try to fetch from the public queue
            std::unique_lock<std::mutex> locked(impl->lock);
            if (task_run_handle t = impl->public_queue.pop()) {
                // Don't hold the lock while running the task
                locked.unlock();
                t.run();
                break;
            }

            // If shutting down and we don't have a task to wait for, return.
            if (!wait_task && impl->shutdown) {
#ifdef BROKEN_JOIN_IN_DESTRUCTOR
                // Notify once all worker threads have exited
                if (--impl->shutdown_num_threads == 0)
                    impl->shutdown_complete_event.notify_one();
#endif
                return;
            }

            // No tasks found, so sleep until something happens.
            // If a continuation has not been added yet, add it.
            if (wait_task && !added_continuation) {
                // Create a continuation for the task we are waiting for
                task_wait_event &event = current_thread.event;
                wait_task.on_finish([&event]
                                    {
                                        // Signal the thread's event
                                        event.signal(wait_type::task_finished);
                                    });
                added_continuation = true;
            }

            // Add our thread to the list of waiting threads
            size_t num_waiters_val = impl->num_waiters.load(std::memory_order_relaxed);
            impl->waiters[num_waiters_val] = &current_thread.event;
            impl->num_waiters.store(num_waiters_val + 1, std::memory_order_relaxed);

            // Wait for our event to be signaled when a task is scheduled or
            // the task we are waiting for has completed.
            locked.unlock();
            int events = current_thread.event.wait();
            locked.lock();

            // Remove our thread from the list of waiting threads
            num_waiters_val = impl->num_waiters.load(std::memory_order_relaxed);
            for (std::size_t i = 0; i < num_waiters_val; i++) {
                if (impl->waiters[i] == &current_thread.event) {
                    if (i != num_waiters_val - 1)
                        std::swap(impl->waiters[i], impl->waiters[num_waiters_val - 1]);
                    impl->num_waiters.store(num_waiters_val - 1, std::memory_order_relaxed);
                    break;
                }
            }

            // Check again if the task has finished
            if (wait_task && (events & wait_type::task_finished))
                return;
        }
    }
}

// Wait for a task to complete (for worker threads inside thread pool)
static void threadpool_wait_handler(task_wait_handle wait_task)
{
    threadpool_data_wrapper wrapper = get_threadpool_data_wrapper();
    thread_task_loop(wrapper.owning_threadpool, wrapper.thread_id, wait_task);
}

// Worker thread main loop
static void worker_thread(threadpool_data *owning_threadpool, std::size_t thread_id)
{
    // store on the local thread data
    create_threadpool_data(owning_threadpool, thread_id);

    // Set the wait handler so threads from the pool do useful work while
    // waiting for another task to finish.
    set_thread_wait_handler(threadpool_wait_handler);

    // Seed the random number generator with our id. This gives each thread a
    // different steal order.
    owning_threadpool->thread_data[thread_id].rng.seed(static_cast<std::minstd_rand::result_type>(thread_id));

    // Prerun hook
    if (owning_threadpool->prerun) owning_threadpool->prerun();

    // Main loop, runs until the shutdown signal is recieved
    thread_task_loop(owning_threadpool, thread_id, task_wait_handle());

    // Postrun hook
    if (owning_threadpool->postrun) owning_threadpool->postrun();
}

// Recursive function to supply_async all worker threads in parallel
static void recursive_spawn_worker_thread(threadpool_data *impl, std::size_t index, std::size_t threads)
{
    // If we are down to one thread, go to the worker main loop
    if (threads == 1)
        worker_thread(impl, index);
    else {
        // Split thread range into 2 sub-ranges
        std::size_t mid = index + threads / 2;

        // Spawn a thread for half of the range
        impl->thread_data[mid].handle = std::thread(recursive_spawn_worker_thread, impl, mid, threads - threads / 2);
#ifdef BROKEN_JOIN_IN_DESTRUCTOR
        impl->thread_data[mid].handle.detach();
#endif

        // Tail-recurse to handle our half of the range
        recursive_spawn_worker_thread(impl, index, threads / 2);
    }
}

} // namespace detail

threadpool_scheduler::threadpool_scheduler(threadpool_scheduler &&other)
    : impl(std::move(other.impl))
{
}

threadpool_scheduler::threadpool_scheduler(std::size_t num_threads)
    : impl(new detail::threadpool_data(num_threads))
{
    // Start worker threads
    impl->thread_data[0].handle = std::thread(detail::recursive_spawn_worker_thread, impl.get(), 0, num_threads);
#ifdef BROKEN_JOIN_IN_DESTRUCTOR
    impl->thread_data[0].handle.detach();
#endif
}

threadpool_scheduler::threadpool_scheduler(std::size_t num_threads,
                                           std::function<void()> &&prerun,
                                           std::function<void()> &&postrun)
    : impl(new detail::threadpool_data(num_threads, std::move(prerun), std::move(postrun)))
{
    // Start worker threads
    impl->thread_data[0].handle = std::thread(detail::recursive_spawn_worker_thread, impl.get(), 0, num_threads);
#ifdef BROKEN_JOIN_IN_DESTRUCTOR
    impl->thread_data[0].handle.detach();
#endif
}

// Wait for all currently running tasks to finish
threadpool_scheduler::~threadpool_scheduler()
{
    if (!impl) return;
#ifdef _WIN32
    // Windows kills all threads except one on process exit before calling
    // global destructors in DLLs. Waiting for dead threads to exit will likely
    // result in deadlocks, so we just exit early if we detect that the process
    // is exiting.
    auto RtlDllShutdownInProgress = reinterpret_cast<BOOLEAN(WINAPI *)()>(GetProcAddress(GetModuleHandleW(L"ntdll.dll"), "RtlDllShutdownInProgress"));
    if (RtlDllShutdownInProgress && RtlDllShutdownInProgress()) {
# ifndef BROKEN_JOIN_IN_DESTRUCTOR
        // We still need to detach the thread handles otherwise the std::thread
        // destructor will throw an exception.
        for (std::size_t i = 0; i < impl->thread_data.size(); i++) {
            try {
                impl->thread_data[i].handle.detach();
            } catch (...) {}
        }
# endif
        return;
    }
#endif

    {
        std::unique_lock<std::mutex> locked(impl->lock);

        // Signal shutdown
        impl->shutdown = true;

        // Wake up any sleeping threads
        size_t num_waiters_val = impl->num_waiters.load(std::memory_order_relaxed);
        for (std::size_t i = 0; i < num_waiters_val; i++)
            impl->waiters[i]->signal(detail::wait_type::task_available);
        impl->num_waiters.store(0, std::memory_order_relaxed);

#ifdef BROKEN_JOIN_IN_DESTRUCTOR
        // Wait for the threads to exit
        impl->shutdown_num_threads = impl->thread_data.size();
        impl->shutdown_complete_event.wait(locked);
#endif
    }

#ifndef BROKEN_JOIN_IN_DESTRUCTOR
    // Wait for the threads to exit
    for (std::size_t i = 0; i < impl->thread_data.size(); i++)
        impl->thread_data[i].handle.join();
#endif
}

// Schedule a task on the thread pool
void threadpool_scheduler::schedule(task_run_handle t)
{
    detail::threadpool_data_wrapper wrapper = detail::get_threadpool_data_wrapper();

    // Check if we are in the thread pool
    if (wrapper.owning_threadpool == impl.get()) {
        // Push the task onto our task queue
        impl->thread_data[wrapper.thread_id].queue.push(std::move(t));

        // If there are no sleeping threads, just return. We check outside the
        // lock to avoid locking overhead in the fast path.
        if (impl->num_waiters.load(std::memory_order_relaxed) == 0)
            return;

        // Get a thread to wake up from the list
        std::lock_guard<std::mutex> locked(impl->lock);

        // Check again if there are waiters
        size_t num_waiters_val = impl->num_waiters.load(std::memory_order_relaxed);
        if (num_waiters_val == 0)
            return;

        // Pop a thread from the list and wake it up
        impl->waiters[num_waiters_val - 1]->signal(detail::wait_type::task_available);
        impl->num_waiters.store(num_waiters_val - 1, std::memory_order_relaxed);
    }
    else {
        std::lock_guard<std::mutex> locked(impl->lock);

        // Push task onto the public queue
        impl->public_queue.push(std::move(t));

        // Wake up a sleeping thread
        size_t num_waiters_val = impl->num_waiters.load(std::memory_order_relaxed);
        if (num_waiters_val == 0)
            return;
        impl->waiters[num_waiters_val - 1]->signal(detail::wait_type::task_available);
        impl->num_waiters.store(num_waiters_val - 1, std::memory_order_relaxed);
    }
}

single_thread_scheduler::single_thread_scheduler()
    : threadpool_scheduler(1)
{}

single_thread_scheduler::single_thread_scheduler(std::function<void()> &&prerun_,
                                                 std::function<void()> &&postrun_)
    : threadpool_scheduler(1, std::move(prerun_), std::move(postrun_))
{

}
single_thread_scheduler::~single_thread_scheduler()
{

}
} // namespace stagefuture

#ifndef LIBASYNC_STATIC
#if defined(__GNUC__) && !defined(_WIN32)
# pragma GCC visibility pop
#endif
#endif
